{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "This simple notebook demonstrates the workflow of using the TensorFlow converter.\n",
    "\"\"\"\n",
    "from __future__ import print_function\n",
    "import numpy as np\n",
    "from tensorflow.python.tools.freeze_graph import freeze_graph\n",
    "\n",
    "import tfcoreml\n",
    "import linear_mnist_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Step 0: Before you run this notebook, view run the example script linear_mnist_train.py\n",
    "to get a trained TensorFlow network.\n",
    "This may take a few minutes.\n",
    "\"\"\"\n",
    "linear_mnist_train.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Step 1: \"Freeze\" the TensorFlow model, i.e. convert it into a stand-alone graph definition file\n",
    "Inputs:\n",
    "(1) TensorFlow code\n",
    "(2) Trained weights in a checkpoint file\n",
    "(3) Names of the output tensors to use for inference\n",
    "(4) [Optional] Names of the input tensors of the TF model\n",
    "Outputs:\n",
    "(1) A frozen TensorFlow GraphDef, with the trained weights frozen into it\n",
    "\"\"\"\n",
    "\n",
    "# Settings handed to freeze_graph below.\n",
    "graph_def_file = './model.pbtxt'  # graph definition, stored as protobuf TEXT\n",
    "checkpoint_file = './checkpoints/model.ckpt'  # checkpoint of the trained model\n",
    "frozen_model_file = './frozen_model.pb'  # file name for the frozen model\n",
    "# Output nodes; for several output ops pass one comma-separated string, e.g. \"out1,out2\".\n",
    "output_node_names = 'Softmax'\n",
    "\n",
    "# Freeze the graph. The graph definition is protobuf text, hence input_binary=False.\n",
    "freeze_graph(\n",
    "    input_graph=graph_def_file,\n",
    "    input_binary=False,\n",
    "    input_checkpoint=checkpoint_file,\n",
    "    output_graph=frozen_model_file,\n",
    "    output_node_names=output_node_names,\n",
    "    input_saver=\"\",\n",
    "    initializer_nodes=\"\",\n",
    "    restore_op_name=\"save/restore_all\",\n",
    "    filename_tensor_name=\"save/Const:0\",\n",
    "    clear_devices=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Step 2: Call the converter\n",
    "\"\"\"\n",
    "\n",
    "# Inputs needed here in addition to the ones from Step 1.\n",
    "# Path of the Core ML model file to produce\n",
    "coreml_model_file = './model.mlmodel'\n",
    "output_tensor_names = ['Softmax:0']\n",
    "# Maps each input tensor name to its shape, batch dimension included (batch size is 1).\n",
    "input_tensor_shapes = {'Placeholder:0': [1, 784]}\n",
    "\n",
    "# Run the converter on the frozen model from Step 1.\n",
    "coreml_model = tfcoreml.convert(\n",
    "    tf_model_path=frozen_model_file,\n",
    "    mlmodel_path=coreml_model_file,\n",
    "    output_feature_names=output_tensor_names,\n",
    "    input_name_shape_dict=input_tensor_shapes)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Step 3: Run the converted model\n",
    "\"\"\"\n",
    "\n",
    "# The Core ML model takes its inputs as a dictionary. The ':0' suffix of the tensor name\n",
    "# becomes '__0' because Swift / Objective-C code generation does not allow colons in variable names.\n",
    "np.random.seed(100)\n",
    "coreml_inputs = {\n",
    "    # (sequence_length=1, batch=1, channels=784)\n",
    "    'Placeholder__0': np.random.rand(1, 1, 784),\n",
    "}\n",
    "coreml_output = coreml_model.predict(\n",
    "    coreml_inputs,\n",
    "    useCPUOnly=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(coreml_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
